{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Dropout intuition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Dropout from scratch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's code our own dropout function first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dropout(X, drop_prob):\n",
    "    \"\"\"Apply (inverted) dropout to `X`.\n",
    "\n",
    "    Each element is zeroed independently with probability `drop_prob`; the\n",
    "    surviving elements are divided by `1 - drop_prob` so that the expected\n",
    "    value of the output equals `X`.\n",
    "\n",
    "    Args:\n",
    "        X: input tensor of any shape.\n",
    "        drop_prob: probability of dropping an element, in [0, 1].\n",
    "\n",
    "    Returns:\n",
    "        A new tensor with the same shape as `X`.\n",
    "    \"\"\"\n",
    "    assert 0 <= drop_prob <= 1\n",
    "    # In this case, all elements are dropped out\n",
    "    if drop_prob == 1:\n",
    "        return torch.zeros_like(X)\n",
    "    # torch.rand samples from [0, 1), so `>=` keeps an element with probability\n",
    "    # exactly 1 - drop_prob (and never drops anything when drop_prob == 0).\n",
    "    # The noise is created on X's device so the mask can be multiplied with X.\n",
    "    mask = torch.rand(X.shape, device=X.device) >= drop_prob\n",
    "    return mask.float() * X / (1.0 - drop_prob)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Small 2x8 float tensor holding 0..15 to try the function on\n",
    "X = torch.arange(16, dtype=torch.float32).reshape(2, 8)\n",
    "print(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop probabilities 0, 0.5 and 1, one printed tensor each\n",
    "print(*(dropout(X, p) for p in (0, 0.5, 1)), sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Dropout on a toy dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_SAMPLES = 20   # number of points in the train set and in the test set\n",
    "N_HIDDEN = 300   # width of the hidden fully connected layers\n",
    "SEED = 0         # fixed so that the noisy data and the training runs are reproducible\n",
    "\n",
    "# Seed PyTorch's global RNG; the trailing `;` hides the returned Generator repr\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's generate some noisy training and test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# training data: points on the line y = x plus Gaussian noise scaled by 0.3\n",
    "x = torch.linspace(-1, 1, N_SAMPLES).unsqueeze(1)\n",
    "y = x + 0.3 * torch.normal(mean=torch.zeros_like(x), std=torch.ones_like(x))\n",
    "\n",
    "# test data: same inputs, separate noise draw\n",
    "test_x = torch.linspace(-1, 1, N_SAMPLES).unsqueeze(1)\n",
    "test_y = test_x + 0.3 * torch.normal(mean=torch.zeros_like(test_x), std=torch.ones_like(test_x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# show data: train and test points on the same axes\n",
    "fig, ax = plt.subplots()\n",
    "ax.scatter(x.numpy(), y.numpy(), c='green', s=50, alpha=0.5, label='train')\n",
    "ax.scatter(test_x.numpy(), test_y.numpy(), c='orange', s=50, alpha=0.5, label='test')\n",
    "ax.legend(loc='upper left')\n",
    "ax.set_ylim(-2.5, 2.5)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Exercise"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create a network `net_simple` as `nn.Sequential` with the following structure: `Fully Connected N_HIDDEN -> ReLU -> Fully Connected N_HIDDEN -> ReLU -> Fully Connected 1`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net_simple = torch.nn.Sequential(\n",
    "#     TODO: add the layers listed in the exercise above, in order.\n",
    "#     The inputs `x` and the targets `y` both have one feature per sample.\n",
    "#     NOTE(review): while this stays empty the model has no parameters, so the\n",
    "#     optimizer cell further down is expected to reject it -- fill this in first.\n",
    ")\n",
    "\n",
    "print(net_simple)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Exercise\n",
    "\n",
    "Now define `net_dropout`: the same architecture as `net_simple`, with `Dropout` layers in-between with $p=0.5$. Where will you place them?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net_dropout = torch.nn.Sequential(\n",
    "#    TODO: same layers as `net_simple`, plus `Dropout` layers with p=0.5.\n",
    "#    NOTE(review): while this stays empty the model has no parameters, so the\n",
    "#    optimizer cell further down is expected to reject it -- fill this in first.\n",
    ")\n",
    "\n",
    "print(net_dropout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared MSE criterion, and one Adam optimizer per network (same learning rate)\n",
    "loss_fn = torch.nn.MSELoss()\n",
    "optimizer_simple, optimizer_dropout = (\n",
    "    torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "    for model in (net_simple, net_dropout)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train both networks side by side on the full training set and, every\n",
    "# 100 epochs, plot their predictions and test losses.\n",
    "for epoch in range(500):\n",
    "    pred_simple = net_simple(x)\n",
    "    pred_drop = net_dropout(x)\n",
    "    loss_simple = loss_fn(pred_simple, y)\n",
    "    loss_dropout = loss_fn(pred_drop, y)\n",
    "\n",
    "    optimizer_simple.zero_grad()\n",
    "    optimizer_dropout.zero_grad()\n",
    "    loss_simple.backward()\n",
    "    loss_dropout.backward()\n",
    "    optimizer_simple.step()\n",
    "    optimizer_dropout.step()\n",
    "\n",
    "    if epoch % 100 == 0:\n",
    "        # Switch to eval mode so that dropout is turned off for evaluation\n",
    "        # (eval mode changes the layers' behaviour, not their parameters)\n",
    "        net_simple.eval()\n",
    "        net_dropout.eval()\n",
    "\n",
    "        # Evaluation needs no gradients: skip building the autograd graph,\n",
    "        # which also lets us call .numpy() / .item() on the results directly\n",
    "        with torch.no_grad():\n",
    "            test_pred_simple = net_simple(test_x)\n",
    "            test_pred_dropout = net_dropout(test_x)\n",
    "            test_loss_simple = loss_fn(test_pred_simple, test_y).item()\n",
    "            test_loss_dropout = loss_fn(test_pred_dropout, test_y).item()\n",
    "\n",
    "        # change back to train mode before plotting, so that a plotting error\n",
    "        # cannot leave the networks stuck in eval mode\n",
    "        net_simple.train()\n",
    "        net_dropout.train()\n",
    "\n",
    "        # plotting\n",
    "        plt.cla()\n",
    "        plt.scatter(x.numpy(), y.numpy(), c='green', s=50, alpha=0.3, label='train')\n",
    "        plt.scatter(test_x.numpy(), test_y.numpy(), c='orange', s=50, alpha=0.3, label='test')\n",
    "        plt.plot(test_x.numpy(), test_pred_simple.numpy(), 'r-', lw=3, label='overfitting')\n",
    "        plt.plot(test_x.numpy(), test_pred_dropout.numpy(), 'b--', lw=3, label='dropout(50%)')\n",
    "        plt.text(0, -1.2, 'overfitting test loss=%.4f' % test_loss_simple, fontdict={'size': 16, 'color':  'red'})\n",
    "        plt.text(0, -1.5, 'dropout test loss=%.4f' % test_loss_dropout, fontdict={'size': 16, 'color': 'blue'})\n",
    "        plt.legend(loc='upper left')\n",
    "        plt.ylim((-2.5, 2.5))\n",
    "        plt.pause(0.1)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
